{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SBgdZAmvWrZQ"
   },
   "source": [
    "第三周作业二\n",
    "---\n",
    "\n",
    "1. 使用完整的 YelpReviewFull 数据集进行训练，对比 Acc 最高能达到多少。\n",
    "\n",
    "2. 加载本地保存的模型，进行评估，并继续训练以获得更高的 F1 Score。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4843,
     "status": "ok",
     "timestamp": 1708048462985,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "OuER-lVWWevL",
    "outputId": "a69b7b2d-a7a4-429b-881a-d39dbc9ba083"
   },
   "outputs": [],
   "source": [
    "# from google.colab import drive\n",
    "# drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "3EIniHuxYYDI"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# The version specifier must be quoted: an unquoted `>` is shell output redirection, which\n",
    "# drops the constraint (plain `torch` is installed) and writes pip's output to a file `=2.1.2`.\n",
    "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
    "%pip install datasets accelerate transformers 'torch>=2.1.2' evaluate scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/huggingface/datasets\n"
     ]
    }
   ],
   "source": [
    "from datasets import config\n",
    "print(config.HF_DATASETS_CACHE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6256,
     "status": "ok",
     "timestamp": 1708047664818,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "x7rv00f9Wxmf",
    "outputId": "001a47ce-f197-4e55-9beb-ccfd2fa37e2a",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ca88fb2b9f145bca187f966381290aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/650000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d38a7f6991434eaab5fe6f8a9287fa08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/50000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"/root/autodl-tmp/yelp_review_full\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1708047664818,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "sDI2_zftYybJ",
    "outputId": "7b31640a-ac45-479b-cc36-f2d9e58c2e44"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['label', 'text'],\n",
       "        num_rows: 650000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['label', 'text'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "id": "yIl3DeR9Ymzo"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    \"\"\"Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.\n",
    "\n",
    "    Columns whose feature type is `datasets.ClassLabel` are rendered with their\n",
    "    label names instead of the integer class ids.\n",
    "\n",
    "    Args:\n",
    "        dataset: object supporting `len()`, indexing with a list of row indices,\n",
    "            and a `.features` mapping of column name to feature type.\n",
    "        num_examples: number of rows to show; must not exceed `len(dataset)`.\n",
    "    \"\"\"\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    # random.sample draws unique indices in a single call, replacing the retry-on-duplicate loop.\n",
    "    picks = random.sample(range(len(dataset)), num_examples)\n",
    "\n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, datasets.ClassLabel):\n",
    "            # Map integer class ids to their human-readable label names.\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1708047664818,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "pNo_xSeLYqY4",
    "outputId": "f32bcda7-95da-42cf-c6a6-f7a48f0354ab"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>I absolutely love this place. The coffee here is easily the best on the Mount. Friendly owner and great service and baked goods to die for!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3 stars</td>\n",
       "      <td>It's a used media store chain. Not to much to say about it. I usually shy away from these places because they run the independent stores out of business. I still stand by that but I gotta say that no matter which one of these I walk into, the people working here always know they're stuff and are really cool which makes me want to buy something from them. This one is no exception, I asked the guy if I could see something on the top shelf and the guy practically ran to the back room to get a ladder, now that's service!  I still have to knock some stars off for the low trade in value they give and the fact that their prices on some stuff is pretty steep for used. Still, the employees rock!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>From start to finish it was an amazing place !!!! Ok, I'll Admit every restaurant I go to it habit to rip it apart on its good and bad attributes. For the most part I tend to spot many flaws.  Last night I did not spot one flaw!! From the smell of the rug, to the well pressed linen, phenom lighting, attentive staff, polished silverware, fresh roses, great lead sommelier, and overall exceptional hospitality, this restaurant hit the mark! Upon being seated we wanted something more private and I literally pointed at a table and in 45 seconds it was ready to go!! You definitely get what you pay for!!!! \\n\\nThe food was AMAZING!! amazing ingredients.  Our start was a bottle of Pinot noir from Santa Barbara, with a sweet pea puree/wild mushrooms/duck prosciutto Amazing! next,  the Exec. Chef made a special rack of lamb for me which was absolutely amazing and cooked to perfection med rare!  Our waiter personally brought me into the kitchen and let me thank him personally! \\n\\nOverall: Great Experience!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1 star</td>\n",
       "      <td>We've been to the Garage several times (always brought other friends with us because it was close and we liked the atmosphere and the food was pretty good). Service was okay. Recently we chose to go there because we had purchased a groupon. Five of us ate dinner (food was okay to decent) and had two drinks a piece. The bill game to $125 without tip. The Garage deducted the $30 groupon, but we didn't notice until after we left that they tacked on an extra $16 fee. I went back to ask why the extra $16 fee? They said if you use the groupon during happy hour you lost all happy hour pricing. This wasn't mentioned at the time otherwise I wouldn't have used my groupon. They said it was on the groupon that it wasn't valid for happy hour (fine, then don't take it and by the way we also had five dinners). We didn't go there for happy hour, we went for dinner. At the end of the day, I spent $15 for the groupon, paid an extra $16 for the happy hour charge (total of $31 dollars) and received a $30 deduction. Not good math. My conclusion - don't go there with your groupon during happy hour, you might get screwed. Also, as a restaurant,  you would think you would want to make a return, local, recommending customer happy, instead of screw them. I won't be returning and I won't be bringing my friends either. We can go to Zipps, the Vig or Phoenix City Grille instead.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>Lobster Tacos...hands down amazing; it comes with 2 and I had to get a second helping of it because it was that good! Try it!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>4 stars</td>\n",
       "      <td>The show was really good and there are not really any bad seats. We enjoyed ourselves and would recommend seeing it.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2 star</td>\n",
       "      <td>I purchased a Groupon for two nights at HRH. The reservation process easy-peasy and I thought I got a steal with only paying $240 for two nights (Thursday/Friday; Feb 7-9). When I arrived to check in the front desk staff was rude and the attendant reprimanded me for not having my Groupon printed out. I showed her the purchased Groupon on my phone and it CLEARLY stated 'No Voucher Necessary'.\\n\\nOnce we were finally checked in, we went to our room on the 11th floor in the Casino Tower. Although we requested a non-smoking room, we opened the door and our senses were bombarded with the fumes of paint with hints of smoke. Apparently they had been working on repairs in the rooms, because the smell of paint was evident throughout the entire floor. Upon further appraisal of the room the d\\u00e9cor was fitting for the room, but the cleanliness was appalling and down right disgusting!! The bathroom was filthy, there was still remnants in the shower from the last guest and the sink and countertops still had heavy soap residue all over it. The toilet looked dirty, especially the back of the toilet (gross) and there were some questionable stains on the walls. I walked into the room and the beds looked clean, but there was stains on the night tables and vanity (clearly not wiped down at all). The floors looked like they hadn't seen a vacuum in a while! I called down to the \\\"friendly\\\" front desk staff asked for another room and I was told the hotel was booked due to an event and nothing else was available...that`s it! No apology, no offering to call housekeeping, nothing! I walked out of the room and tracked down a housekeeping attendant and asked her to freshen up the room. She was very annoyed, but got down to business.\\n\\nPros:\\nD\\u00e9cor is fitting for the hotel concept\\nDecent variety of eateries\\n\\nCons:\\nResort fee..REALLY!?!?! Why isn't this included in the room costs??? 
This fee is unnecessary and the \\\"perks\\\" for paying it is NOT worth it!\\nRude/lackluster customer service from the point of contact personnel (front desk staff)\\nRun down..at least in the Casino Towers, and a bit unsanitary (based on the room they were going to allow me to stay in)\\n\\nI`ll go elsewhere next time.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>2 star</td>\n",
       "      <td>I am currently at store. It's been an hr for a pickup of two pizzas. Only two people are working. One on register one making pizzas. When u have $2 pizzas you expect this to be a busy time. Why only one person? It is only $6 total with 2 pizzas and tax. But holy cow be aware you'll be busy. And this has nothing to do with the food because I haven't even touched it yet.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2 star</td>\n",
       "      <td>D\\u00e9sol\\u00e9 ! Je ne partage pas l'avis de mes confr\\u00e8res yelpeurs au sujet de cette chaine de chocolatier. Les prix sont ridiculement \\u00e9lev\\u00e9s malgr\\u00e9 la soit distante qualit\\u00e9 du produit, d'autant plus que les chocolats sont minuscules. \\n\\nJe suis pass\\u00e9 dans l'une de leurs succursales pour offrir un petit cadeau \\u00e0 ma conjointe. J'ai d\\u00fb cracher 12 $ pour une boite de quatre chocolats ! Pour ajouter l'insulte \\u00e0 l'injure, ma conjointe n'a pas aim\\u00e9 deux des quatre saveurs que j'avais choisies pour elle.    \\n\\nUne \\u00e9toile pour le design et la pr\\u00e9sentation des produits, une \\u00e9toile pour le service. C'est tout.\\n\\nPourtant, il semble que beaucoup de gens aiment. Ah snobisme! Quand tu nous tiens...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>3 stars</td>\n",
       "      <td>Came here immediately after our flight landed. Was here with a few friends on a recruiting trip and they had gone to school at CMU previously. \\n\\nWe attempted to eat and drink our way to our per diem -- to no avail!!\\n\\nThe 1/2 price menu had our bill for 4 people at 50$. Including massive margaritas, appetizers, and enough food to take home (which of course we didn't, who wants left-over mad mex?!?!).\\n\\nThe scene was pretty horrible tho. very very crowded, and a ton of drunk college kids. Yah, I know, I'm in a college town. Just reminds me of why I moved away from mine, :).</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(dataset[\"train\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fa7b9f5b51c48b19b72d9268ed71b08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/650000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f632578b42c4d03807a6dcf1d480c8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/50000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/root/autodl-tmp/models/bert-base-cased\")\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True,max_length=512)\n",
    "    \n",
    "tokenized_datasets = dataset.map(tokenize_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "512"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the tokenizer pads/truncates every example to max_length, so this should be 512.\n",
    "len(tokenized_datasets[\"train\"][0]['token_type_ids'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WtqqUlYBZORZ"
   },
   "source": [
    "上述代码使用模型对应的 tokenizer 对每条输入进行预处理，即对 text 字段进行截断和填充。\n",
    "\n",
    "1. 为什么要进行截断？\n",
    "2. 为什么要进行填充？\n",
    "\n",
    "原因在于神经网络以矩阵运算处理数据：同一批次内的所有输入必须具有相同的长度，才能堆叠成形状规则的张量。因此在文本向量化过程中，长度不足 max_length（此处为 512）的序列会被填充补齐，超过该长度的序列则会被截断。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 324
    },
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1708047668563,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "YUzAXbXLdU66",
    "outputId": "8ff88d41-5470-46bf-897e-44846a50b138"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "      <th>input_ids</th>\n",
       "      <th>token_type_ids</th>\n",
       "      <th>attention_mask</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3 stars</td>\n",
       "      <td>If you go to Cave Creek you must go to the Buffalo Chip, Harold's , hideaway etc.  just great local restaurants in a great town north of Phoenix/Scottsdale.   I like the Chip more for drinking dancing and rodeo than the food. But it is okay for a little grub   But it is fun.  So try it</td>\n",
       "      <td>[101, 1409, 1128, 1301, 1106, 11900, 3063, 1128, 1538, 1301, 1106, 1103, 6911, 20379, 117, 6587, 112, 188, 117, 4750, 7138, 3576, 119, 1198, 1632, 1469, 7724, 1107, 170, 1632, 1411, 1564, 1104, 6343, 120, 2796, 20537, 119, 146, 1176, 1103, 20379, 1167, 1111, 5464, 5923, 1105, 8335, 1186, 1190, 1103, 2094, 119, 1252, 1122, 1110, 3008, 1111, 170, 1376, 176, 5082, 1830, 1252, 1122, 1110, 4106, 119, 1573, 2222, 1122, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]</td>\n",
       "      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 650000\n",
       "})"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_random_elements(tokenized_datasets[\"train\"], num_examples=1)\n",
    "tokenized_datasets[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "id": "UjS3-9sjg_25"
   },
   "outputs": [],
   "source": [
    "small_train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42).select(range(1000))\n",
    "small_eval_dataset = tokenized_datasets[\"test\"].shuffle(seed=42).select(range(1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 1000\n",
       "})"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "small_train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 650000\n",
       "})"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets[\"train\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TK6r2uTbiNqD"
   },
   "source": [
    "## 1.微调bert模型\n",
    "\n",
    "有哪些需要注意的？\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2083,
     "status": "ok",
     "timestamp": 1708047671228,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "tvrZyXsbiIJy",
    "outputId": "d8d5fa3e-c6e2-4987-8667-23674d00ce47"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /root/autodl-tmp/models/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"/root/autodl-tmp/models/bert-base-cased\", num_labels=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "id": "CKD6j081i4ca"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "# Directory that receives the fine-tuned model and trainer output.\n",
    "model_dir = \"/root/autodl-tmp/ouput/models/bert-base-cased-finetune-yelp\"\n",
    "\n",
    "# logging_steps defaults to 500; given our training data size and step count, use 100 instead.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=model_dir,\n",
    "    per_device_train_batch_size=16,\n",
    "    num_train_epochs=5,\n",
    "    logging_steps=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qb5KQ78klEDC"
   },
   "source": [
    "## 2.指标评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "6b4548838d7c4ac9aa7de443509de918",
      "bfd64dbd2d054998af3ad520a4300249",
      "88692cfb98054e34afd4b2e4ca03c342",
      "762a04405c624c23a9358a0f5649c922",
      "e632aa7019e54d6883f126aea24c32d6",
      "8344381b241640bab591233868e5bdf7",
      "7070078e5d224e4892521d2f3a90b078",
      "e8d3e4ee45a54b8185ecafc76aa43ab4",
      "7292138df59144499f687a1967fc3c85",
      "1f29816cc56643ef9f6ff0d6b52053bb",
      "71d42f3426d641e6b0ac1b1e449218b7"
     ]
    },
    "executionInfo": {
     "elapsed": 13880,
     "status": "ok",
     "timestamp": 1708047899815,
     "user": {
      "displayName": "boy golang",
      "userId": "17669684625027249527"
     },
     "user_tz": -480
    },
    "id": "IcsggDp1k--p",
    "outputId": "f81f92a5-3e1f-423f-b5fb-c8f2a5e0dbec"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Score arg-max predictions against the reference labels with the loaded accuracy metric.\n",
    "\n",
    "    `eval_pred` unpacks into (logits, labels); the predicted class of each row is\n",
    "    the index of its largest logit along the last axis.\n",
    "    \"\"\"\n",
    "    scores, references = eval_pred\n",
    "    predicted_classes = np.argmax(scores, axis=-1)\n",
    "    return metric.compute(predictions=predicted_classes, references=references)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pgT2_4WQlsXr"
   },
   "source": [
    "## 3.训练参数中添加评估策略"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "id": "qKwO_xQQl0JL"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=model_dir,\n",
    "    evaluation_strategy=\"epoch\",  # report evaluation metrics at the end of every epoch\n",
    "    # Default checkpointing (every 500 steps, old checkpoints never removed) hit\n",
    "    # \"No space left on device\" at the first checkpoint of the full-dataset run.\n",
    "    # Checkpoint once per epoch and keep only the most recent one to bound disk usage.\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    per_device_train_batch_size=16,\n",
    "    num_train_epochs=3,  # train for 3 epochs\n",
    "    logging_steps=30,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yBJZqc7omKb5"
   },
   "source": [
    "## 4.进行微调训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "id": "dzs_M5ucmGwr"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.2+cu121\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='189' max='189' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [189/189 02:09, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.604600</td>\n",
       "      <td>1.522083</td>\n",
       "      <td>0.340000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.177100</td>\n",
       "      <td>1.035253</td>\n",
       "      <td>0.566000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.869800</td>\n",
       "      <td>0.965289</td>\n",
       "      <td>0.593000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=189, training_loss=1.2474889931855377, metrics={'train_runtime': 129.5584, 'train_samples_per_second': 23.156, 'train_steps_per_second': 1.459, 'total_flos': 789354427392000.0, 'train_loss': 1.2474889931855377, 'epoch': 3.0})"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "small_train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42).select(range(1000))\n",
    "small_eval_dataset = tokenized_datasets[\"test\"].shuffle(seed=42).select(range(1000))\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "id": "KXZZjRmzmS05"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='13' max='13' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [13/13 00:01]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 1.0351089239120483,\n",
       " 'eval_accuracy': 0.49,\n",
       " 'eval_runtime': 1.5574,\n",
       " 'eval_samples_per_second': 64.208,\n",
       " 'eval_steps_per_second': 8.347,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "small_test_dataset = tokenized_datasets[\"test\"].shuffle(seed=64).select(range(100))\n",
    "trainer.evaluate(small_test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "id": "bBw06ARnqQUR"
   },
   "outputs": [],
   "source": [
    "trainer.save_model(model_dir)\n",
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.使用完整的 YelpReviewFull 数据集训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='501' max='121875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [   501/121875 04:06 < 16:38:01, 2.03 it/s, Epoch 0.01/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "SafetensorError",
     "evalue": "Error while serializing: IoError(Os { code: 28, kind: StorageFull, message: \"No space left on device\" })",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mSafetensorError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[58], line 13\u001b[0m\n\u001b[1;32m      4\u001b[0m small_eval_dataset \u001b[38;5;241m=\u001b[39m tokenized_datasets[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m      6\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m      7\u001b[0m     model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m      8\u001b[0m     args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     11\u001b[0m     compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics,\n\u001b[1;32m     12\u001b[0m )\n\u001b[0;32m---> 13\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1555\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1553\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1555\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1556\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1557\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1558\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1559\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1560\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1922\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1919\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m   1920\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 1922\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1923\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1924\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2282\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2279\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr_scheduler\u001b[38;5;241m.\u001b[39mstep(metrics[metric_to_check])\n\u001b[1;32m   2281\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 2282\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetrics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2283\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_save(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2350\u001b[0m, in \u001b[0;36mTrainer._save_checkpoint\u001b[0;34m(self, model, trial, metrics)\u001b[0m\n\u001b[1;32m   2348\u001b[0m run_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_output_dir(trial\u001b[38;5;241m=\u001b[39mtrial)\n\u001b[1;32m   2349\u001b[0m output_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(run_dir, checkpoint_folder)\n\u001b[0;32m-> 2350\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_internal_call\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m   2351\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_deepspeed_enabled:\n\u001b[1;32m   2352\u001b[0m     \u001b[38;5;66;03m# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed\u001b[39;00m\n\u001b[1;32m   2353\u001b[0m     \u001b[38;5;66;03m# config `stage3_gather_16bit_weights_on_model_save` is True\u001b[39;00m\n\u001b[1;32m   2354\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped\u001b[38;5;241m.\u001b[39msave_checkpoint(output_dir)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2843\u001b[0m, in \u001b[0;36mTrainer.save_model\u001b[0;34m(self, output_dir, _internal_call)\u001b[0m\n\u001b[1;32m   2840\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped\u001b[38;5;241m.\u001b[39msave_checkpoint(output_dir)\n\u001b[1;32m   2842\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 2843\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2845\u001b[0m \u001b[38;5;66;03m# Push to the Hub when `save_model` is called by the user.\u001b[39;00m\n\u001b[1;32m   2846\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpush_to_hub \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _internal_call:\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2901\u001b[0m, in \u001b[0;36mTrainer._save\u001b[0;34m(self, output_dir, state_dict)\u001b[0m\n\u001b[1;32m   2899\u001b[0m             torch\u001b[38;5;241m.\u001b[39msave(state_dict, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(output_dir, WEIGHTS_NAME))\n\u001b[1;32m   2900\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2901\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2902\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msafe_serialization\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_safetensors\u001b[49m\n\u001b[1;32m   2903\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2905\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2906\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer\u001b[38;5;241m.\u001b[39msave_pretrained(output_dir)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/modeling_utils.py:2187\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m   2183\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m   2184\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m   2185\u001b[0m         \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m   2186\u001b[0m         \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2187\u001b[0m         \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2188\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2189\u001b[0m         save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m    250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m    251\u001b[0m     tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m    252\u001b[0m     filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m    253\u001b[0m     metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    254\u001b[0m ):\n\u001b[1;32m    255\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    256\u001b[0m \u001b[38;5;124;03m    Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m    257\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    279\u001b[0m \u001b[38;5;124;03m    ```\u001b[39;00m\n\u001b[1;32m    280\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m     \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mSafetensorError\u001b[0m: Error while serializing: IoError(Os { code: 28, kind: StorageFull, message: \"No space left on device\" })"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Both splits are taken as-is: no subsampling is applied in this cell despite the `small_` prefix.\n",
    "small_train_dataset, small_eval_dataset = tokenized_datasets[\"train\"], tokenized_datasets[\"test\"]\n",
    "\n",
    "# Build the Trainer from the model, arguments and metric function defined in earlier cells.\n",
    "trainer = Trainer(\n",
    "    compute_metrics=compute_metrics,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    train_dataset=small_train_dataset,\n",
    "    args=training_args,\n",
    "    model=model,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the trainer's model on 100 test examples drawn with a fixed shuffle seed.\n",
    "small_test_dataset = (\n",
    "    tokenized_datasets[\"test\"]\n",
    "    .shuffle(seed=64)\n",
    "    .select(range(100))\n",
    ")\n",
    "trainer.evaluate(eval_dataset=small_test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the trained model into `model_dir`, then save the Trainer's own state.\n",
    "trainer.save_model(output_dir=model_dir)\n",
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='87297' max='121875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 87297/121875 13:06:10 < 5:11:24, 1.85 it/s, Epoch 2.15/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.812100</td>\n",
       "      <td>0.765708</td>\n",
       "      <td>0.665740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.688000</td>\n",
       "      <td>0.742344</td>\n",
       "      <td>0.678560</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "# Start from the weights stored in the earlier run's checkpoint-8500 directory.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"/root/models/bert-base-cased-finetune-yelp/checkpoint-8500\", num_labels=5)\n",
    "# NOTE: the path is spelled \"ouput\" (sic); it is left unchanged so checkpoints already written there are still found.\n",
    "model_dir = \"/root/autodl-tmp/ouput/models/bert-base-cased-finetune-yelp\"\n",
    "\n",
    "training_args = TrainingArguments(output_dir=model_dir,\n",
    "                  evaluation_strategy=\"epoch\",  # report evaluation metrics at the end of every epoch\n",
    "                  per_device_train_batch_size=16,\n",
    "                  num_train_epochs=3,  # train for 3 epochs\n",
    "                  save_total_limit=2,  # keep only the 2 newest checkpoints so periodic saves cannot fill the disk\n",
    "                  logging_steps=30)  # log the training loss every 30 steps\n",
    "\n",
    "# Both splits are taken as-is: no subsampling is applied in this cell despite the `small_` prefix.\n",
    "small_train_dataset = tokenized_datasets[\"train\"]\n",
    "small_eval_dataset = tokenized_datasets[\"test\"]\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "# If this run is interrupted, resume from the most recent checkpoint: older ones are rotated away.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='121875' max='121875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [121875/121875 2:29:25, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.592200</td>\n",
       "      <td>0.781156</td>\n",
       "      <td>0.684280</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=121875, training_loss=0.07394406636164738, metrics={'train_runtime': 8965.7596, 'train_samples_per_second': 217.494, 'train_steps_per_second': 13.593, 'total_flos': 5.130803778048e+17, 'train_loss': 0.07394406636164738, 'epoch': 3.0})"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "# The Trainer needs a model instance up front; on resume its weights are reloaded from the checkpoint.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"/root/models/bert-base-cased-finetune-yelp\", num_labels=5)\n",
    "# NOTE: the path is spelled \"ouput\" (sic); it is left unchanged so checkpoints already written there are still found.\n",
    "model_dir = \"/root/autodl-tmp/ouput/models/bert-base-cased-finetune-yelp\"\n",
    "\n",
    "training_args = TrainingArguments(output_dir=model_dir,\n",
    "                  evaluation_strategy=\"epoch\",  # report evaluation metrics at the end of every epoch\n",
    "                  per_device_train_batch_size=16,\n",
    "                  num_train_epochs=3,  # train for 3 epochs\n",
    "                  save_total_limit=2,  # keep only the 2 newest checkpoints so periodic saves cannot fill the disk\n",
    "                  logging_steps=30)  # log the training loss every 30 steps\n",
    "\n",
    "# Both splits are taken as-is: no subsampling is applied in this cell despite the `small_` prefix.\n",
    "small_train_dataset = tokenized_datasets[\"train\"]\n",
    "small_eval_dataset = tokenized_datasets[\"test\"]\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "# Resume from the newest checkpoint found in `model_dir` rather than a hard-coded step number,\n",
    "# so the cell still works after older checkpoints have been rotated away.\n",
    "trainer.train(resume_from_checkpoint=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the trained model into `model_dir`, then save the Trainer's own state.\n",
    "trainer.save_model(output_dir=model_dir)\n",
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='13' max='13' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [13/13 00:01]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.8879196643829346,\n",
       " 'eval_accuracy': 0.63,\n",
       " 'eval_runtime': 1.2675,\n",
       " 'eval_samples_per_second': 78.898,\n",
       " 'eval_steps_per_second': 10.257,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the trainer's model on 100 test examples drawn with a fixed shuffle seed.\n",
    "small_test_dataset = (\n",
    "    tokenized_datasets[\"test\"]\n",
    "    .shuffle(seed=64)\n",
    "    .select(range(100))\n",
    ")\n",
    "trainer.evaluate(eval_dataset=small_test_dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyN06reuc7t/8XwzNOz06Moo",
   "collapsed_sections": [
    "TK6r2uTbiNqD"
   ],
   "gpuType": "T4",
   "mount_file_id": "1KMaRZTD_43g9uZ5Ku0c_vCY9SfDJSX6t",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "1f29816cc56643ef9f6ff0d6b52053bb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "6b4548838d7c4ac9aa7de443509de918": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_bfd64dbd2d054998af3ad520a4300249",
       "IPY_MODEL_88692cfb98054e34afd4b2e4ca03c342",
       "IPY_MODEL_762a04405c624c23a9358a0f5649c922"
      ],
      "layout": "IPY_MODEL_e632aa7019e54d6883f126aea24c32d6"
     }
    },
    "7070078e5d224e4892521d2f3a90b078": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "71d42f3426d641e6b0ac1b1e449218b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "7292138df59144499f687a1967fc3c85": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "762a04405c624c23a9358a0f5649c922": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1f29816cc56643ef9f6ff0d6b52053bb",
      "placeholder": "​",
      "style": "IPY_MODEL_71d42f3426d641e6b0ac1b1e449218b7",
      "value": " 4.20k/4.20k [00:00&lt;00:00, 100kB/s]"
     }
    },
    "8344381b241640bab591233868e5bdf7": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "88692cfb98054e34afd4b2e4ca03c342": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_e8d3e4ee45a54b8185ecafc76aa43ab4",
      "max": 4203,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7292138df59144499f687a1967fc3c85",
      "value": 4203
     }
    },
    "bfd64dbd2d054998af3ad520a4300249": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8344381b241640bab591233868e5bdf7",
      "placeholder": "​",
      "style": "IPY_MODEL_7070078e5d224e4892521d2f3a90b078",
      "value": "Downloading builder script: 100%"
     }
    },
    "e632aa7019e54d6883f126aea24c32d6": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e8d3e4ee45a54b8185ecafc76aa43ab4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
